#include "BillboardGeometryShader.h"

BillboardGeometryShader::BillboardGeometryShader(ID3D11Device* device, HWND hwnd)
	: BowerShader(device, hwnd)
{
	// Compiled shader objects for the vertex, geometry and pixel stages of this shader
	const wchar_t* const vertexShaderFile = L"bypass_vs.cso";
	const wchar_t* const geometryShaderFile = L"billboard_gs.cso";
	const wchar_t* const pixelShaderFile = L"billboard_ps.cso";

	initShader(vertexShaderFile, geometryShaderFile, pixelShaderFile);
}

// Releases every D3D11 object this shader holds. Each pointer is nulled after Release()
// so that any later cleanup that inspects the same member sees it as already freed.
//
// NOTE: the base-class destructors (BowerShader and, through it, BaseShader) run
// automatically once this body finishes. They must NOT be invoked explicitly here:
// doing so destroys the base subobject twice, which is undefined behaviour and can
// double-Release the base's COM objects.
BillboardGeometryShader::~BillboardGeometryShader()
{
	//Release all the constant buffers
	if (shadowMatrixBuffer)
	{
		shadowMatrixBuffer->Release();
		shadowMatrixBuffer = 0;
	}

	if (billboardBuffer)
	{
		billboardBuffer->Release();
		billboardBuffer = 0;
	}

	if (billboardPositionsBuffer)
	{
		billboardPositionsBuffer->Release();
		billboardPositionsBuffer = 0;
	}

	if (lightBuffer)
	{
		lightBuffer->Release();
		lightBuffer = 0;
	}

	if (matrixBuffer)
	{
		matrixBuffer->Release();
		matrixBuffer = 0;
	}

	//Release the samplers
	if (sampleState)
	{
		sampleState->Release();
		sampleState = 0;
	}

	if (sampleStateShadow)
	{
		sampleStateShadow->Release();
		sampleStateShadow = 0;
	}

	//Release the input layout
	if (layout)
	{
		layout->Release();
		layout = 0;
	}

	// Base shader components are released by the base-class destructors, which the
	// compiler calls automatically after this point.
}

void BillboardGeometryShader::initShader(const wchar_t* vsFilename, const wchar_t* psFilename)
{
	// Load (+ compile) shader files
	loadVertexShader(vsFilename);
	loadPixelShader(psFilename);

	//Setup the shadow matrix buffer
	D3D11_BUFFER_DESC shadowMatrixBufferDesc;
	shadowMatrixBufferDesc.Usage = D3D11_USAGE_DYNAMIC;
	shadowMatrixBufferDesc.ByteWidth = sizeof(ShadowMatrixBufferType);
	shadowMatrixBufferDesc.BindFlags = D3D11_BIND_CONSTANT_BUFFER;
	shadowMatrixBufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
	shadowMatrixBufferDesc.MiscFlags = 0;
	shadowMatrixBufferDesc.StructureByteStride = 0;
	renderer->CreateBuffer(&shadowMatrixBufferDesc, NULL, &shadowMatrixBuffer);

	//Setup the light matrix buffer
	D3D11_BUFFER_DESC lightBufferDesc;
	lightBufferDesc.Usage = D3D11_USAGE_DYNAMIC;
	lightBufferDesc.ByteWidth = sizeof(LightBufferType);
	lightBufferDesc.BindFlags = D3D11_BIND_CONSTANT_BUFFER;
	lightBufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
	lightBufferDesc.MiscFlags = 0;
	lightBufferDesc.StructureByteStride = 0;
	renderer->CreateBuffer(&lightBufferDesc, NULL, &lightBuffer);

	//Setup the billboard buffer
	D3D11_BUFFER_DESC billboardBufferDesc;
	billboardBufferDesc.Usage = D3D11_USAGE_DYNAMIC;
	billboardBufferDesc.ByteWidth = sizeof(BillboardBufferType);
	billboardBufferDesc.BindFlags = D3D11_BIND_CONSTANT_BUFFER;
	billboardBufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
	billboardBufferDesc.MiscFlags = 0;
	billboardBufferDesc.StructureByteStride = 0;
	renderer->CreateBuffer(&billboardBufferDesc, NULL, &billboardBuffer);

	//Setup the billboard positions buffer
	D3D11_BUFFER_DESC billboardPositionsBufferDesc;
	billboardPositionsBufferDesc.Usage = D3D11_USAGE_DYNAMIC;
	billboardPositionsBufferDesc.ByteWidth = sizeof(BillboardPositionsBufferType);
	billboardPositionsBufferDesc.BindFlags = D3D11_BIND_CONSTANT_BUFFER;
	billboardPositionsBufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
	billboardPositionsBufferDesc.MiscFlags = 0;
	billboardPositionsBufferDesc.StructureByteStride = 0;
	renderer->CreateBuffer(&billboardPositionsBufferDesc, NULL, &billboardPositionsBuffer);

	//Setup the base sampler
	D3D11_SAMPLER_DESC samplerDesc;
	samplerDesc.Filter = D3D11_FILTER_MIN_MAG_MIP_POINT;
	samplerDesc.AddressU = D3D11_TEXTURE_ADDRESS_CLAMP;
	samplerDesc.AddressV = D3D11_TEXTURE_ADDRESS_CLAMP;
	samplerDesc.AddressW = D3D11_TEXTURE_ADDRESS_CLAMP;
	samplerDesc.MipLODBias = 0.0f;
	samplerDesc.MaxAnisotropy = 1;
	samplerDesc.ComparisonFunc = D3D11_COMPARISON_ALWAYS;
	samplerDesc.MinLOD = 0;
	samplerDesc.MaxLOD = D3D11_FLOAT32_MAX;
	renderer->CreateSamplerState(&samplerDesc, &sampleState);

	//Setup the sampler shadow
	D3D11_SAMPLER_DESC shadowSamplerDesc;
	shadowSamplerDesc.Filter = D3D11_FILTER_MIN_MAG_MIP_POINT;
	shadowSamplerDesc.AddressU = D3D11_TEXTURE_ADDRESS_BORDER;
	shadowSamplerDesc.AddressV = D3D11_TEXTURE_ADDRESS_BORDER;
	shadowSamplerDesc.AddressW = D3D11_TEXTURE_ADDRESS_BORDER;
	shadowSamplerDesc.BorderColor[0] = 1.0f;
	shadowSamplerDesc.BorderColor[1] = 1.0f;
	shadowSamplerDesc.BorderColor[2] = 1.0f;
	shadowSamplerDesc.BorderColor[3] = 1.0f;
	renderer->CreateSamplerState(&shadowSamplerDesc, &sampleStateShadow);
}

void BillboardGeometryShader::initShader(const wchar_t* vsFilename, const wchar_t* gsFilename, const wchar_t* psFilename)
{
	initShader(vsFilename, psFilename);
	loadGeometryShader(gsFilename);
}

// Fills the four quad-corner offsets written to the billboard positions constant buffer.
// The quad is a square in the XY plane whose side length depends on the geometry type:
// 5 units for a Tree, 1.75 units for anything else (grass).
//   ptr  - mapped BillboardPositionsBufferType to write into (must be non-null)
//   type - which kind of billboard is being drawn
void BillboardGeometryShader::SetPositions(BillboardPositionsBufferType* ptr, GeometryType type)
{
	const float quadSize = (type == Tree) ? 5.0f : 1.75f;

	// Corners in the order (0,s), (0,0), (s,s), (s,0); the w component is 0.
	// NOTE(review): the order presumably matches how the geometry shader emits its strip — confirm.
	ptr->positions[0] = XMFLOAT4(0.0f, quadSize, 0.0f, 0.0f);
	ptr->positions[1] = XMFLOAT4(0.0f, 0.0f, 0.0f, 0.0f);
	ptr->positions[2] = XMFLOAT4(quadSize, quadSize, 0.0f, 0.0f);
	ptr->positions[3] = XMFLOAT4(quadSize, 0.0f, 0.0f, 0.0f);
}

// Uploads all per-draw data for the billboard shader and binds buffers, textures and samplers.
//   deviceContext     - immediate/deferred context to bind on
//   worldMatrix,
//   viewMatrix,
//   projectionMatrix  - camera transforms (transposed before upload)
//   cameraPosition    - written to the billboard buffer (GS b1)
//   type              - selects the billboard quad size (see SetPositions), GS b2
//   billboardTexture  - PS t0
//   sunShadowMap      - PS t1
//   spotShadowMap     - PS t2
//   hillShadowMaps    - six shadow maps for the hill light, PS t3..t8, one per axis direction
//   constant, linear,
//   quadratic         - written to the light buffer as constantFactor/linearFactor/quadraticFactor
//                       (presumably attenuation terms — confirm against the pixel shader)
// Side effect: the hill light's direction is overwritten while building its six view matrices
// and is left at the last direction (0, 0, -1) on return.
// If any Map call fails, the function returns early without writing through the invalid pointer.
void BillboardGeometryShader::setShaderParameters(ID3D11DeviceContext* deviceContext, const XMMATRIX &worldMatrix, const XMMATRIX &viewMatrix, const XMMATRIX &projectionMatrix, XMFLOAT3 cameraPosition, GeometryType type, ID3D11ShaderResourceView* billboardTexture, ID3D11ShaderResourceView* sunShadowMap, ID3D11ShaderResourceView* spotShadowMap, ID3D11ShaderResourceView* hillShadowMaps[6], float constant, float linear, float quadratic)
{
	// Number of hill-light shadow directions (and therefore shadow maps / matrix pairs)
	const UINT hillFaceCount = 6;

	D3D11_MAPPED_SUBRESOURCE mappedResource;

	// Transpose for HLSL, whose constant-buffer matrices default to column-major packing
	XMMATRIX tworld = XMMatrixTranspose(worldMatrix);
	XMMATRIX tview = XMMatrixTranspose(viewMatrix);
	XMMATRIX tproj = XMMatrixTranspose(projectionMatrix);

	//Map the billboard buffer (GS b1)
	if (FAILED(deviceContext->Map(billboardBuffer, 0, D3D11_MAP_WRITE_DISCARD, 0, &mappedResource)))
	{
		return;
	}
	BillboardBufferType* geometryPtr = static_cast<BillboardBufferType*>(mappedResource.pData);
	geometryPtr->cameraPosition = XMFLOAT4(cameraPosition.x, cameraPosition.y, cameraPosition.z, 1.0f);
	deviceContext->Unmap(billboardBuffer, 0);
	deviceContext->GSSetConstantBuffers(1, 1, &billboardBuffer);

	//Map the positions buffer (GS b2)
	if (FAILED(deviceContext->Map(billboardPositionsBuffer, 0, D3D11_MAP_WRITE_DISCARD, 0, &mappedResource)))
	{
		return;
	}
	BillboardPositionsBufferType* positionsPtr = static_cast<BillboardPositionsBufferType*>(mappedResource.pData);
	SetPositions(positionsPtr, type);
	deviceContext->Unmap(billboardPositionsBuffer, 0);
	deviceContext->GSSetConstantBuffers(2, 1, &billboardPositionsBuffer);

	//Map the light buffer (PS b0)
	if (FAILED(deviceContext->Map(lightBuffer, 0, D3D11_MAP_WRITE_DISCARD, 0, &mappedResource)))
	{
		return;
	}
	LightBufferType* lightPtr = static_cast<LightBufferType*>(mappedResource.pData);
	lightPtr->sunAmbient = _SunLight->getAmbientColour();
	lightPtr->sunDiffuse = _SunLight->getDiffuseColour();
	lightPtr->sunDirection = _SunLight->getDirection();
	lightPtr->padding = 0.0f;
	lightPtr->spotDiffuse = _SpotLight->getDiffuseColour();
	lightPtr->spotPosition = _SpotLight->getPosition();
	lightPtr->constantFactor = constant;
	lightPtr->linearFactor = linear;
	lightPtr->quadraticFactor = quadratic;
	lightPtr->paddingTwo = XMFLOAT2(0.0f, 0.0f);
	lightPtr->spotDirection = _SpotLight->getDirection();
	lightPtr->paddingThree = 0.0f;
	lightPtr->hillDiffuse = _HillLight->getDiffuseColour();
	lightPtr->hillPosition = _HillLight->getPosition();
	lightPtr->paddingFour = 0.0f;
	deviceContext->Unmap(lightBuffer, 0);
	deviceContext->PSSetConstantBuffers(0, 1, &lightBuffer);

	//Map the shadow matrix buffer (GS b0, PS b1)
	if (FAILED(deviceContext->Map(shadowMatrixBuffer, 0, D3D11_MAP_WRITE_DISCARD, 0, &mappedResource)))
	{
		return;
	}
	ShadowMatrixBufferType* shadowMatrixPtr = static_cast<ShadowMatrixBufferType*>(mappedResource.pData);
	shadowMatrixPtr->world = tworld;
	shadowMatrixPtr->view = tview;
	shadowMatrixPtr->projection = tproj;
	shadowMatrixPtr->sunLightView = XMMatrixTranspose(_SunLight->getViewMatrix());
	shadowMatrixPtr->sunLightProjection = XMMatrixTranspose(_SunLight->getOrthoMatrix());

	//Get the view matrix for the spot light. A direction along the pure Y axis is handled by
	//GetYAxisViewMatrix instead of generateViewMatrix (negated when pointing down).
	XMMATRIX spotViewMatrix;
	if (_SpotLight->getDirection().x == 0 && _SpotLight->getDirection().z == 0)
	{
		spotViewMatrix = GetYAxisViewMatrix(_SpotLight);

		if (_SpotLight->getDirection().y < 0.0f)
		{
			spotViewMatrix = -spotViewMatrix;
		}
	}

	else
	{
		_SpotLight->generateViewMatrix();
		spotViewMatrix = _SpotLight->getViewMatrix();
	}

	//Set the values into the mapped pointer
	shadowMatrixPtr->spotLightView = XMMatrixTranspose(spotViewMatrix);
	shadowMatrixPtr->spotLightProjection = XMMatrixTranspose(_SpotLight->getProjectionMatrix());

	//Get the view matrices for the hill light, one per axis-aligned direction
	XMFLOAT3 lightDirections[hillFaceCount] =
	{
		XMFLOAT3(0.0f, 1.0f, 0.0f),
		XMFLOAT3(0.0f, -1.0f, 0.0f),
		XMFLOAT3(1.0f, 0.0f, 0.0f),
		XMFLOAT3(-1.0f, 0.0f, 0.0f),
		XMFLOAT3(0.0f, 0.0f, 1.0f),
		XMFLOAT3(0.0f, 0.0f, -1.0f)
	};

	for (UINT i = 0; i < hillFaceCount; i++)
	{
		_HillLight->setDirection(lightDirections[i].x, lightDirections[i].y, lightDirections[i].z);
		XMMATRIX hillViewMatrix;

		if (_HillLight->getDirection().x == 0 && _HillLight->getDirection().z == 0)
		{
			hillViewMatrix = GetYAxisViewMatrix(_HillLight);

			if (_HillLight->getDirection().y < 0.0f)
			{
				hillViewMatrix = -hillViewMatrix;
			}
		}

		else
		{
			_HillLight->generateViewMatrix();
			hillViewMatrix = _HillLight->getViewMatrix();
		}

		//Set the values into the mapped pointer
		shadowMatrixPtr->hillLightViews[i] = XMMatrixTranspose(hillViewMatrix);
		shadowMatrixPtr->hillLightProjections[i] = XMMatrixTranspose(_HillLight->getProjectionMatrix());
	}

	deviceContext->Unmap(shadowMatrixBuffer, 0);
	deviceContext->GSSetConstantBuffers(0, 1, &shadowMatrixBuffer);
	deviceContext->PSSetConstantBuffers(1, 1, &shadowMatrixBuffer);

	//Set the pixel shader textures and samplers
	deviceContext->PSSetShaderResources(0, 1, &billboardTexture);
	deviceContext->PSSetShaderResources(1, 1, &sunShadowMap);
	deviceContext->PSSetShaderResources(2, 1, &spotShadowMap);
	//Bind all hill shadow maps (t3..t8), not just the first one, to match the six
	//hill view/projection matrices uploaded above
	deviceContext->PSSetShaderResources(3, hillFaceCount, hillShadowMaps);
	deviceContext->PSSetSamplers(0, 1, &sampleState);
	deviceContext->PSSetSamplers(1, 1, &sampleStateShadow);
}
